package eshore.cn.it.classification;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.util.logging.Logger;

import com.aliasi.classify.Classification;
import com.aliasi.classify.Classified;
import com.aliasi.classify.DynamicLMClassifier;
import com.aliasi.lm.NGramProcessLM;
import com.aliasi.util.Files;

import eshore.cn.it.configuration.ClassModelConfiguration;

/***
 * NGramClassierTrainer 是利用 NGram 特征提取的方式训练分类模型
 * 最后保存模型到指定的位置。
 * 1、训练模型（利用全新的参数）
 * 2、在已有的模型上面继续更新训练
 * 模型一定要通过NGramClassierTest测试之后，达到一定的精度（目前还没有写十则交叉验证，只有二择交叉验证）
 * 才能利用全部样本训练模型，投入实际生产环境。
 * @author clebeg	2015-04-14 10:40
 */
public class NGramClassierTrainer {
	private static Logger logger = 
			Logger.getLogger(NGramClassierTrainer.class.getName());
	/**
	 * ngram特征提取，当前字符与前几个字符相关
	 * 已模型训练时候的参数设置为准，默认设置为3
	 * 太大容易出现过拟合。
	 */
	private int ngramSize = 3;
	public int getNgramSize() {
		return ngramSize;
	}
	public void setNgramSize(int ngramSize) {
		this.ngramSize = ngramSize;
	}
	
	private ClassModelConfiguration modelConfig;
	public ClassModelConfiguration getModelConfig() {
		return modelConfig;
	}
	public void setModelConfig(ClassModelConfiguration modelConfig) {
		this.modelConfig = modelConfig;
	}
	
	//基于NGram提取特征之后的分类器
	private  DynamicLMClassifier<NGramProcessLM> classifier;
	
	
	public void trainModel() throws IOException {
		classifier
			= DynamicLMClassifier.createNGramProcess(this.getModelConfig().getCategories(), 
					this.getNgramSize());
		
		for(int i = 0; i < this.getModelConfig().getCategories().length; ++i) {
            File classDir = new File(this.getModelConfig().getTrainRootDir(), 
            		this.getModelConfig().getCategories()[i]);
            
            if (!classDir.isDirectory()) {
                String msg = "无法找到包含训练数据的路径 ="
                    + classDir
                    + "\n你是否未在训练root目录建立  "
                    + this.getModelConfig().getCategories().length
                    + "类别?";
                logger.severe(msg); // in case exception gets lost in shell
                throw new IllegalArgumentException(msg);
            }
            
            String[] trainingFiles = classDir.list();
            for (int j = 0; j < trainingFiles.length; ++j) {
                File file = new File(classDir, trainingFiles[j]);
                String text = Files.readFromFile(file, "GBK");
                System.out.println("正在训练   -> " + this.getModelConfig().getCategories()[i] + "/" + trainingFiles[j]);
                Classification classification
                    = new Classification(this.getModelConfig().getCategories()[i]);
                Classified<CharSequence> classified
                    = new Classified<CharSequence>(text, classification);
                classifier.handle(classified);
            }
        }
		File modelFile = new File(this.getModelConfig().getModelFile());
		System.out.println("所有类别均训练完成.开始保存模型到 " + modelFile.getAbsolutePath());
		if (modelFile.exists() == false) {
			logger.info("指定模型文件不存在，将自动建立上层目录，并建立文件...");
			modelFile.getParentFile().mkdirs();
		} else {
			logger.warning("指定模型文件已经存在，将覆盖...");
		}
		
		ObjectOutputStream os = new ObjectOutputStream(new FileOutputStream(  
				modelFile));  
        classifier.compileTo(os);  
        os.close(); 
        
        System.out.println("模型保持成功！");
	}
	
}
